!==============================================================================
! Copyright (C) 2010, University Corporation for Atmospheric Research,
!                     Colorado State University,
!                     Los Alamos National Security, LLC,
!                     United States Department of Energy
!
! All rights reserved.  See ../COPYING for copyright details
!==============================================================================

module matrix_mod

! !DESCRIPTION:
!  This module defines the sparse matrix type used by the linear
!  (conjugate gradient) solver.  It provides a routine that assembles
!  the barotropic 9-point stencil into a compressed sparse row (CSR)
!  matrix and a routine that multiplies that matrix by a vector.  When
!  built with _USEESSL, the CSR matrix is also converted to the ESSL
!  compressed-matrix format, and the product is computed by ESSL.
!
! !REVISION HISTORY:
!  CVS: $Id: matrix_mod.F90,v 1.10 2006/05/30 16:02:00 dennis Exp $
!  CVS: $Name:  $

! !USES:

   use kinds_mod, only: i4, r8
   use blocks, only: nx_block,ny_block, get_block_parameter
   use domain, only: blocks_tropic, nblocks_tropic
   use domain_size, only: max_blocks_tropic
   use linear, only: max_linear, ldof

   implicit none
   private

! !PUBLIC TYPES:

    type, public :: Matrix_t
       integer(i4) ::  &
          n,       &  ! The order of the matrix (not set in this module)
          nz,      &  ! The number of non-zero elements in CSR format
          maxNZ,   &  ! The capacity of Mat/Ja (set by initMatrix)
          nz_essl     ! The nz argument passed to ESSL dsrsm/dsmmx
       ! CSR non-zero values, stored row by row
       real(r8), dimension(:), pointer :: Mat => null()
#ifdef _USEESSL
       ! Non-zero values in ESSL compressed-matrix format
       real(r8), dimension(:,:), pointer :: ac => null()
#endif
       ! CSR column index of each entry of Mat
       integer(i4), dimension(:), pointer :: Ja => null()
       ! CSR row pointer: row r occupies Mat(Ia(r):Ia(r+1)-1)
       integer(i4), dimension(:), pointer :: Ia => null()
#ifdef _USEESSL
       ! Column indices in ESSL compressed-matrix format; this must be
       ! a pointer (or allocatable) to have a deferred shape
       integer(i4), dimension(:,:), pointer :: ka => null()
#endif
    end type Matrix_t


! !PUBLIC MEMBER FUNCTIONS:

    public :: ConvertStencil
    public :: matvec
    public :: initMatrix

!-----------------------------------------------------------------------
!
!  module variables
!
!-----------------------------------------------------------------------

! !PUBLIC DATA MEMBERS:

    type (Matrix_t), public :: A   ! The matrix for the conjugate gradient solver

    ! Maximum number of non-zeros in one row: the 9-point stencil
    integer(i4), parameter :: stencil_size = 9

!EOP
!EOC
!***********************************************************************

contains

!***********************************************************************
!BOP
! !IROUTINE: initMatrix
! !INTERFACE:

  subroutine initMatrix()

! !DESCRIPTION:
!  Allocates the storage of the module matrix A for at most max_linear
!  rows.  Each row has at most stencil_size non-zeros, so
!  stencil_size*max_linear values and column indices are reserved.  The
!  CSR row pointer needs one more entry than the number of rows.  The
!  order A%n is not set here.
!EOP
!BOC

      allocate(A%Mat(stencil_size*max_linear))
      allocate(A%Ja(stencil_size*max_linear))
      ! n+1 row pointers are needed.  The "+1" also keeps Ia(1) valid
      ! when max_linear is 0.
      allocate(A%Ia(max_linear+1))
#ifdef _USEESSL
      allocate(A%ac(max_linear,stencil_size))
      allocate(A%ka(max_linear,stencil_size))
#endif
      A%maxNZ = stencil_size*max_linear

!EOC
  end subroutine initMatrix

!***********************************************************************
!BOP
! !IROUTINE: ConvertStencil
! !INTERFACE:

 subroutine ConvertStencil(Mat,A0,AN,AE,ANE)

! !DESCRIPTION:
!  This subroutine assembles the 9-point stencil into a CSR matrix for
!  the conjugate gradient solver.  Each ocean point (ldof > 0) becomes
!  one row, indexed by its ldof value.  Each ocean neighbour becomes one
!  non-zero in that row, in the fixed order centre, N, S, E, W, NE, SE,
!  NW, SW.  Land neighbours (ldof <= 0) are skipped.  Per the original
!  author, only the diagonal A0 is time dependent and the non-zero
!  pattern is static, so this routine could be simplified.
!
!  Mat%n must be set by the caller, and Mat must have been allocated
!  (see initMatrix).
!
!  NOTE(review): entries are stored in (iblock, j, i) traversal order,
!  while Ia is built by ldof row number.  This is only consistent if
!  ldof numbers ocean points consecutively in that same traversal
!  order.  The stencil also reads i-1, i+1, j-1 and j+1, so the block's
!  ib..ie and jb..je ranges are assumed to leave a halo of at least one
!  cell.
!
! !REVISION HISTORY:
!  same as module

! !INPUT PARAMETERS:

   real (r8), intent(in), &
      dimension(nx_block,ny_block,max_blocks_tropic) :: &
      A0,AN,AE,ANE   ! barotropic (9pt) operator coefficients

! !OUTPUT PARAMETERS:

   type (Matrix_t), intent(inout) :: Mat  ! The matrix equiv of the 9-pt stencil

!EOP
!BOC
!-----------------------------------------------------------------------
!
!  local variables
!
!-----------------------------------------------------------------------

   integer(i4) ::          &
      iblock, i, j,        &  ! block and grid-point loop indices
      ib, ie, jb, je,      &  ! loop bounds of the current block
      npoints,             &  ! number of points reported for the block
      row,                 &  ! matrix row (ldof) of the point (i,j)
      row_start,           &  ! first CSR slot used by the current row
      slot                    ! next free CSR slot in Mat%Mat / Mat%Ja

   integer(i4), allocatable, dimension(:) :: &
      nz_per_row              ! number of non-zeros stored for each row

!-----------------------------------------------------------------------

   !--------------------
   ! Zero out the matrix
   !--------------------
   Mat%Mat = 0.0_r8
   Mat%Ja  = 0
   Mat%Ia  = 0
#ifdef _USEESSL
   Mat%ac = 0.0_r8
   Mat%ka = 0
#endif
   allocate(nz_per_row(Mat%n))
   nz_per_row = 0

   !---------------------------------------------
   ! Loop through each block and form the matrix
   ! equivalent of the 9-point stencil
   !---------------------------------------------
   slot = 1
   do iblock = 1, nblocks_tropic
      call get_block_parameter(blocks_tropic(iblock),npoints=npoints, &
                               ib=ib,ie=ie,jb=jb,je=je)
      if (npoints <= 0) cycle

      do j = jb, je
      do i = ib, ie
         row = ldof(i,j,iblock)
         if (row <= 0) cycle      ! land point: it has no matrix row

         row_start = slot
         call add_entry(i  , j  , A0 (i  ,j  ,iblock))   ! centre
         call add_entry(i  , j+1, AN (i  ,j  ,iblock))   ! north
         call add_entry(i  , j-1, AN (i  ,j-1,iblock))   ! south
         call add_entry(i+1, j  , AE (i  ,j  ,iblock))   ! east
         call add_entry(i-1, j  , AE (i-1,j  ,iblock))   ! west
         call add_entry(i+1, j+1, ANE(i  ,j  ,iblock))   ! north-east
         call add_entry(i+1, j-1, ANE(i  ,j-1,iblock))   ! south-east
         call add_entry(i-1, j+1, ANE(i-1,j  ,iblock))   ! north-west
         call add_entry(i-1, j-1, ANE(i-1,j-1,iblock))   ! south-west
         nz_per_row(row) = slot - row_start
      enddo
      enddo
   enddo
   Mat%nz = slot - 1

   !----------------------------------------------------------
   ! Set up the row pointer: row r occupies Ia(r):Ia(r+1)-1
   !----------------------------------------------------------
   Mat%Ia(1) = 1
   do row = 1, Mat%n
      Mat%Ia(row+1) = Mat%Ia(row) + nz_per_row(row)
   enddo

#ifdef _USEESSL
   !-----------------------------------
   ! initialize the ESSL data structure
   !-----------------------------------
   ! NOTE(review): nz_essl was previously passed to dsrsm uninitialised.
   ! It is primed here with the number of columns available in ac/ka.
   ! Confirm against the ESSL DSRSM documentation whether nz is
   ! input, output, or both.
   Mat%nz_essl = size(Mat%ac, 2)
   call dsrsm(0, Mat%Mat,Mat%Ja,Mat%Ia,Mat%n,Mat%nz_essl,Mat%ac,Mat%ka,Mat%n)
#endif
   deallocate(nz_per_row)

!-----------------------------------------------------------------------
!EOC

 contains

   subroutine add_entry(ic, jc, coef)
   ! Appends coef to the next CSR slot, with column ldof(ic,jc,iblock),
   ! if that neighbour is an ocean point (ldof > 0).  Otherwise it
   ! stores nothing.
      integer(i4), intent(in) :: ic, jc   ! grid indices of the neighbour
      real(r8),    intent(in) :: coef     ! stencil coefficient for it

      integer(i4) :: col                  ! matrix column of the neighbour

      col = ldof(ic, jc, iblock)
      if (col > 0) then
         Mat%Ja(slot)  = col
         Mat%Mat(slot) = coef
         slot = slot + 1
      endif
   end subroutine add_entry

  end subroutine ConvertStencil

!***********************************************************************
!BOP
! !IROUTINE: matvec
! !INTERFACE:

  subroutine matvec(n,Mat,Y,X)

! !DESCRIPTION:
!  This routine calculates the matrix vector product,
!
!	\begin{equation}
!		y = Mat*x
!	\end{equation}
!
!  where Mat is stored in CSR format or, with _USEESSL, in ESSL
!  compressed-matrix format.  The product covers rows 1..Mat%n.  The
!  argument n is not referenced; the order is always taken from Mat%n.
!  In the CSR path only Y(1:Mat%n) is assigned.
!
! !REVISION HISTORY:
!  same as module

! !INPUT PARAMETERS:

    integer (i4), intent(in) :: n ! Nominal order of the matrix (unused; Mat%n is used)

    type (Matrix_t), intent(in)    :: Mat ! The matrix to apply

    real (r8), intent(in),dimension(max_linear):: X    ! The operand vector X

! !OUTPUT PARAMETERS:

    real (r8), intent(out),dimension(max_linear):: Y    ! The result vector Y

!EOP
!BOC
!-----------------------------------------------------------------------
!
!  local variables
!
!-----------------------------------------------------------------------

#ifndef _USEESSL
    real(r8)    :: acc      ! running dot product of one row with X
    integer(i4) :: row, k   ! matrix row and CSR entry indices
#endif

#ifdef _USEESSL
    !------------------------------------
    ! If we are using the ESSL format,
    ! perform the matrix vector multiply
    !------------------------------------
    call dsmmx(Mat%n,Mat%nz_essl,Mat%ac,Mat%ka,Mat%n,X,Y)
#else
    !----------------------------------------
    ! Basic CSR matrix vector multiply:
    ! Y(row) = sum over the row's entries of Mat(k)*X(Ja(k))
    !----------------------------------------
    do row = 1, Mat%n
       acc = 0.0_r8
       do k = Mat%Ia(row), Mat%Ia(row+1) - 1
          acc = acc + Mat%Mat(k)*X(Mat%Ja(k))
       enddo
       Y(row) = acc
    enddo
#endif

!EOC
!-----------------------------------------------------------------------

  end subroutine matvec

end module matrix_mod

!|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
